package com.primeton.poctag.task;

import lombok.extern.slf4j.Slf4j;

import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

/**
 * <pre>
 *
 * Created by zhaopx.
 * User: zhaopx
 * Date: 2020/11/16
 * Time: 15:36
 *
 * </pre>
 *
 * @author zhaopx
 */
@Slf4j
public class Data3CSparkQueue {


    /**
     * 任务缓存
     */
    private final static Map<String, Object> TASK_QUEUE = new ConcurrentHashMap<>();


    /**
     * 任务错误队列
     */
    private final static Map<String, Object> ERROR_TASK_QUEUE = new ConcurrentHashMap<>();


    /**
     * 正在运行的队列
     */
    private final static Set<String> RUNNING_TASK_QUEUE = new HashSet<>();


    /**
     * 控制 Spark 并发的信号量
     */
    private final Semaphore semaphore;


    /**
     * 公平锁
     */
    private final static Lock LOCK = new ReentrantLock();



    private static Data3CSparkQueue SPARK_QUEUE;

    private Data3CSparkQueue(int permits) {
        semaphore = new Semaphore(permits);
    }


    /**
     * 初次调用有效，
     * @return
     */
    public static Data3CSparkQueue getInstance() {
        return getInstance(3);
    }


    /**
     * 按照配置，设置并发量。 第一次调用有效
     * @param permits
     * @return
     */
    public static synchronized Data3CSparkQueue getInstance(int permits) {
        if(SPARK_QUEUE == null) {
            SPARK_QUEUE = new Data3CSparkQueue(permits);
        }
        return SPARK_QUEUE;
    }




    /**
     * 添加任务
     * @param taskInstanceId
     * @param taskInfo
     */
    public static boolean addTask(String taskInstanceId, Map taskInfo) {
        LOCK.lock();
        try {
            if(!TASK_QUEUE.containsKey(taskInstanceId)) {
                TASK_QUEUE.put(taskInstanceId, taskInfo);
                log.info("add task: {} , params: {}", taskInstanceId, String.valueOf(taskInfo));
                return true;
            }
        } finally {
            LOCK.unlock();
        }
        return false;
    }


    /**
     * 获取当前需要执行队列的长度
     * @return
     */
    public static int getPendingTaskSize() {
        LOCK.lock();
        try {
            HashMap<String, Object> tmpMap = new HashMap<>(TASK_QUEUE);
            for (String s : RUNNING_TASK_QUEUE) {
                tmpMap.remove(s);
            }
            return tmpMap.size();
        } finally {
            LOCK.unlock();
        }
    }


    /**
     * 获取当前需要执行队列
     * @return
     */
    public static Set<String> getPendingTasks() {
        LOCK.lock();
        try {
            HashMap<String, Object> tmpMap = new HashMap<>(TASK_QUEUE);
            for (String s : RUNNING_TASK_QUEUE) {
                tmpMap.remove(s);
            }
            return tmpMap.keySet();
        } finally {
            LOCK.unlock();
        }
    }


    /**
     * 获取当前正在执行任务的长度
     * @return
     */
    public static int getRunningTaskSize() {
        return RUNNING_TASK_QUEUE.size();
    }



    public static Object getTaskInfo(String taskId) {
        return TASK_QUEUE.get(taskId);
    }


    /**
     * 移除任务
     * @param taskId
     */
    public static void removeTask(String taskId) {
        LOCK.lock();
        try {
            TASK_QUEUE.remove(taskId);
            RUNNING_TASK_QUEUE.remove(taskId);
            log.info("remove task: {}", taskId);
        } finally {
            LOCK.unlock();
        }
    }


    /**
     * 错误的任务报告
     * @param taskId
     */
    public static void reportError(String taskId) {
        LOCK.lock();
        try {
            Object errorTaskInfo = TASK_QUEUE.remove(taskId);
            ERROR_TASK_QUEUE.put(taskId, errorTaskInfo);
            RUNNING_TASK_QUEUE.remove(taskId);
        } finally {
            LOCK.unlock();
        }
    }


    /**
     * 判断任务是否正在运行
     * @param taskId
     * @return
     */
    public static boolean runningTask(String taskId) {
        return RUNNING_TASK_QUEUE.contains(taskId);
    }


    /**
     * 执行该函数
     * @param executor
     * @param task
     */
    public Map execute(ExecutorService executor, final SparkTask task) {
        final Future<?> future = executor.submit((Callable<? extends Object>) () -> {
            final String runningTaskId = task.getTaskId();
            // 有任务需要运行
            if (Data3CSparkQueue.runningTask(runningTaskId)) {
                // 取得的待运行的task，不能是正在运行的列表中的
                log.info("task {} running.", runningTaskId);
                return Collections.emptyMap();
            }
            // 获得一个许可
            try {
                semaphore.acquire();
            } catch (InterruptedException e) {
                return Collections.emptyMap();
            }
            try {
                // 运行任务
                RUNNING_TASK_QUEUE.add(runningTaskId);
                log.info("running task: {}", runningTaskId);
                final Map<String, Object> result = task.call();
                log.info("finished task: {}", runningTaskId);
                // 执行成功，移除
                removeTask(runningTaskId);
                return result;
            } catch (Exception e) {
                log.info("执行任务异常。error task: " + runningTaskId, e);
                // 运行错误
                reportError(runningTaskId);
                throw e;
            } finally {
                // 释放许可
                semaphore.release();
            }
        });

        // 获得结果
        try {
            return (Map) future.get();
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }
}
